Fix CUDA graph parameter grad lifetime #2937
Conversation
Signed-off-by: Robin Zhang <robinz@nvidia.com>
Force-pushed from 4cc8b89 to 6e16c63.
Greptile Summary: This PR fixes a lifetime bug in CUDA graph replay where parameter gradients returned from Graphed.backward could alias CUDA graph static buffers that are overwritten by later replays.
Confidence Score: 5/5. Safe to merge: the fix correctly isolates CUDA graph static buffers from downstream autograd users without breaking existing aliasing semantics for delayed-wgrad parameters. The slot offset arithmetic matches how module parameters occupy the tail of static_grad_inputs. No files require special attention.
Reviews (5). Last reviewed commit: "Add CUDA graph param grad clone opt-out".
```python
def _is_returned_param_grad_slot(idx, static_grad_inputs, module_params):
    """Return whether a static grad slot is consumed through Graphed.backward."""
    module_param_start = len(static_grad_inputs) - len(module_params)
    if idx < module_param_start:
        return False
    return not getattr(
        module_params[idx - module_param_start], "skip_backward_post_hook", False
    )
```
Timing inconsistency between capture and replay attribute reads
_is_returned_param_grad_slot reads skip_backward_post_hook live at both capture time (line 748) and replay time (line 945). If a caller flips the attribute between those two points, the weak-ref decision at capture and the clone decision at replay get out of sync.
Specifically, if the attribute was False at capture (→ static buffer was weak-refed in the _reuse_graph_input_output_buffers path) but True at replay (→ code calls .detach() instead of .detach().clone()), the returned tensor is a detached view of an already-released weak-ref buffer whose memory may have been reused. Snapshotting the skip_backward_post_hook state once at capture time and storing it alongside the static grad slot (or asserting it is unchanged at replay) would make the contract explicit.
Addressed in 4077b85. The parameter-grad clone policy is now snapshotted at capture time and passed into Graphed.backward, so replay no longer re-reads skip_backward_post_hook. Added test_make_graphed_callables_snapshots_parameter_grad_clone_policy to cover changing the attribute after capture.
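For illustration, a minimal sketch of the snapshot-at-capture approach discussed above; the function names and the clone_policy list are hypothetical stand-ins, not the PR's actual code:

```python
def capture_param_grad_clone_policy(module_params):
    """At capture time, record once whether each parameter grad slot should be
    cloned on return (i.e. the parameter is not a delayed-wgrad parameter)."""
    return [
        not getattr(p, "skip_backward_post_hook", False)
        for p in module_params
    ]

def return_param_grads(static_param_grads, clone_policy):
    """At replay time, use the snapshotted policy instead of re-reading the
    attribute, so the capture-time and replay-time decisions can never diverge."""
    return tuple(
        grad.detach().clone() if should_clone else grad.detach()
        for grad, should_clone in zip(static_param_grads, clone_policy)
    )
```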
```python
def test_make_graphed_callables_with_interleaved_pipeline_parallelism_reused_buffers(
    *,
    model_config: str = "small",
    dtype: torch.dtype = torch.float16,
) -> None:
    """Test CUDA graphs with reused input/output buffers."""
    model_config = model_configs[model_config]
    kwargs = dict(model_config=model_config, dtype=dtype)
    outputs = _test_cuda_graphs_with_interleaved_pipeline_parallelism(
        with_graph=False,
        **kwargs,
    )
    graph_outputs = _test_cuda_graphs_with_interleaved_pipeline_parallelism(
        with_graph=True,
        reuse_graph_input_output_buffers=True,
        **kwargs,
    )
    assert_all_equal(outputs, graph_outputs)
```
Reused-buffer test only validates forward outputs, not gradient correctness
test_make_graphed_callables_with_interleaved_pipeline_parallelism_reused_buffers compares output_snapshots (forward tensors cloned before the corresponding backward) against the eager baseline. If the clone-on-return logic in Graphed.backward had a bug specifically in the _reuse_graph_input_output_buffers + pipeline path (e.g., gradient accumulation or an incorrect static buffer being read), weights would diverge but the test would still pass. A weight-equality check after one full schedule would strengthen confidence in the gradient path for this mode.
Addressed in 4077b85. The interleaved pipeline helper now returns final weights in addition to outputs, and the reused-buffer test compares graph/eager final weights to cover gradient correctness. Full tests/pytorch/test_cuda_graphs.py passed on H100: 415 passed, 423 skipped.
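As a rough sketch of the kind of weight check suggested above (the helper and parameter names here are illustrative, not the test suite's actual API):

```python
import torch

def assert_final_weights_match(eager_model: torch.nn.Module,
                               graphed_model: torch.nn.Module) -> None:
    """Exact weight comparison after one full schedule, so gradient bugs that
    leave forward outputs unchanged still fail the test."""
    eager_params = dict(eager_model.named_parameters())
    for name, graphed_param in graphed_model.named_parameters():
        torch.testing.assert_close(
            graphed_param, eager_params[name], rtol=0.0, atol=0.0
        )
```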
Force-pushed from 441e419 to beff9c1.
Signed-off-by: Robin Zhang <robinz@nvidia.com>
Force-pushed from beff9c1 to 4077b85.
I have a problem with this PR. On one hand, I agree that there is a bug here and we should ensure that the gradient buffers are not overwritten before being applied. On the other hand, for the cases without gradient accumulation (or when the accumulation is done in a different way, like in Megatron), it is a performance loss. Could we make this behavior optional: have it on by default, but also provide an opt-out option with a clear warning that states where this would be applicable?
Signed-off-by: Robin Zhang <robinz@nvidia.com>
Force-pushed from 1a3c2c5 to c8b2ee3.
@ptrendx I added a new opt-out argument for the parameter-grad clone behavior, as requested.
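For context, a hypothetical usage sketch of such an opt-out. The keyword name and the commented-out flag are illustrative, since the actual argument name is not visible in the truncated comment above; the call assumes te.make_graphed_callables takes a module and sample args in the style of torch.cuda.make_graphed_callables:

```python
import torch
import transformer_engine.pytorch as te

model = te.Linear(1024, 1024)
graphed_model = te.make_graphed_callables(
    model,
    (torch.randn(32, 1024, device="cuda", requires_grad=True),),
    # Default behavior: returned parameter grads are cloned so later graph
    # replays cannot overwrite them. The opt-out (name illustrative) would
    # skip the clone, e.g. when accumulation is handled externally as in
    # Megatron-style training:
    # retain_param_grads=False,
)
```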
Summary
Fix CUDA graph replay so parameter gradients returned from Graphed.backward do not expose CUDA graph static buffers to downstream autograd users. The fix clones returned parameter gradients before handing them back to autograd, while preserving the existing aliasing behavior for delayed-wgrad parameters marked with skip_backward_post_hook.
Root Cause
When CUDA graph replay returns parameter grad slots directly from static graph buffers, downstream autograd users can retain references to buffers that are overwritten by later graph replays. This can corrupt retained grads or break gradient accumulation semantics.
This is related to PyTorch issue pytorch/pytorch#181723.
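The failure mode can be reproduced outside Transformer Engine with raw CUDA graphs; a small illustrative sketch (not the actual TE code path) where a retained reference to a static buffer is silently overwritten by the next replay:

```python
import torch

x = torch.ones(4, device="cuda")
static_grad = torch.zeros(4, device="cuda")

# Warm-up on a side stream, as recommended before CUDA graph capture.
s = torch.cuda.Stream()
with torch.cuda.stream(s):
    static_grad.copy_(x * 2)
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_grad.copy_(x * 2)  # stands in for a captured backward pass

g.replay()
retained = static_grad       # a downstream user keeps the grad slot directly
safe = static_grad.clone()   # what the fix does before returning the grad

x.fill_(5.0)
g.replay()                   # the next replay overwrites the static buffer

print(retained)  # reflects the second replay: tensor([10., 10., 10., 10.])
print(safe)      # still the first replay's value: tensor([2., 2., 2., 2.])
```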
Changes
Clone parameter gradients before returning them from Graphed.backward, so downstream autograd users never hold views of CUDA graph static buffers.